package middleware

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"time"
	"unicode/utf8"

	"github.com/cookieY/yee"
	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/ext"
	"github.com/uber/jaeger-client-go/config"
)

// defaultComponentName is the span "component" tag value used when
// TraceConfig.ComponentName is left empty.
const defaultComponentName = "yee"

// TraceConfig configures the middleware returned by TraceWithConfig.
type TraceConfig struct {
	// Tracer is the OpenTracing tracer used to start per-request spans. It
	// must be obtained beforehand; TraceWithConfig panics when it is nil.
	Tracer opentracing.Tracer
	// ComponentName is recorded as the span's component tag. Empty means
	// defaultComponentName.
	ComponentName string
	// IsBodyDump enables logging of the request and response bodies on the span.
	IsBodyDump bool
	// LimitHTTPBody shortens dumped bodies with limitString, keeping roughly
	// LimitSize bytes (half from the start, half from the end).
	LimitHTTPBody bool
	// LimitSize is the byte budget applied when LimitHTTPBody is set.
	LimitSize int
}

// responseDumper is an http.ResponseWriter whose Write goes through mw, so the
// response body can be captured in buf and inspected after the handler runs.
type responseDumper struct {
	http.ResponseWriter

	mw  io.Writer     // destination for Write; tees into buf
	buf *bytes.Buffer // captured copy of the response body
}

// New creates a Jaeger tracer, installs it as the global OpenTracing tracer,
// and registers the tracing middleware (TraceWithConfig) on e.
//
// Built-in defaults: service name "echo-tracer", a "const" sampler with
// Param 1, span logging enabled, and a one-second reporter flush interval.
// These defaults are passed through config.FromEnv, so Jaeger environment
// variables can override them.
//
// New panics if the environment cannot be parsed or the tracer cannot be
// created. It returns the tracer's io.Closer; the caller owns it and should
// close it on shutdown.
func New(e *yee.Core) io.Closer {
	sampler := &config.SamplerConfig{Type: "const", Param: 1}
	reporter := &config.ReporterConfig{LogSpans: true, BufferFlushInterval: time.Second}

	// NOTE(review): "echo-tracer" looks inherited from echo's tracing
	// middleware; presumably JAEGER_SERVICE_NAME (read by FromEnv) should be
	// set to report under a meaningful name — confirm.
	base := config.Configuration{
		ServiceName: "echo-tracer",
		Sampler:     sampler,
		Reporter:    reporter,
	}

	cfg, err := base.FromEnv()
	if err != nil {
		panic("Could not parse Jaeger env vars: " + err.Error())
	}

	tracer, closer, err := cfg.NewTracer()
	if err != nil {
		panic("Could not initialize jaeger tracer: " + err.Error())
	}

	opentracing.SetGlobalTracer(tracer)
	e.Use(TraceWithConfig(TraceConfig{Tracer: tracer}))

	return closer
}

// TraceWithConfig returns a yee middleware that runs each request inside an
// OpenTracing span created by config.Tracer.
//
// Parent span:
//   - If config.Tracer can extract a span context from the incoming HTTP
//     headers, the new span joins that trace as an RPC server span.
//   - Otherwise a new root span is started.
//
// The span is tagged with the HTTP method, URL, component name, client IP,
// request ID and final status code. It is stored in the request's context
// (opentracing.ContextWithSpan) before the rest of the chain runs.
//
// With config.IsBodyDump set, the request and response bodies are also logged
// on the span under "http.req.body" and "http.resp.body". They are shortened
// with limitString(..., config.LimitSize) when config.LimitHTTPBody is set.
//
// TraceWithConfig panics if config.Tracer is nil. An empty
// config.ComponentName falls back to defaultComponentName.
func TraceWithConfig(config TraceConfig) yee.HandlerFunc {
	if config.Tracer == nil {
		panic("yee: trace middleware requires opentracing tracer")
	}
	if config.ComponentName == "" {
		config.ComponentName = defaultComponentName
	}

	// logBody logs a request/response body on sp under key, honoring the
	// LimitHTTPBody / LimitSize settings.
	logBody := func(sp opentracing.Span, key, body string) {
		if config.LimitHTTPBody {
			body = limitString(body, config.LimitSize)
		}
		sp.LogKV(key, body)
	}

	return func(c yee.Context) error {
		req := c.Request()
		opname := "HTTP " + req.Method + " URL: " + c.Path()
		realIP := c.RemoteIP()
		requestID := getRequestID(c) // request-id generated by reverse-proxy

		// A failed Extract (e.g. the caller sent no trace headers) only means
		// this span starts a new trace; it is not a request failure. Keep the
		// error scoped to this statement so it can never be mistaken for a
		// handler error after the chain runs.
		var sp opentracing.Span
		if parent, err := config.Tracer.Extract(
			opentracing.HTTPHeaders,
			opentracing.HTTPHeadersCarrier(req.Header),
		); err != nil {
			sp = config.Tracer.StartSpan(opname)
		} else {
			sp = config.Tracer.StartSpan(opname, ext.RPCServerOption(parent))
		}
		defer sp.Finish()

		ext.HTTPMethod.Set(sp, req.Method)
		ext.HTTPUrl.Set(sp, req.URL.String())
		ext.Component.Set(sp, config.ComponentName)
		sp.SetTag("client_ip", realIP)
		sp.SetTag("request_id", requestID)

		var respDumper *responseDumper
		if config.IsBodyDump {
			var reqBody []byte
			if req.Body != nil {
				var readErr error
				reqBody, readErr = ioutil.ReadAll(req.Body)
				if readErr != nil {
					// The handler will only see the bytes read so far; make the
					// truncation visible on the span instead of hiding it.
					sp.LogKV("http.req.body.error", readErr.Error())
				}
				logBody(sp, "http.req.body", string(reqBody))
			}
			// ReadAll consumed the body; hand the handler a fresh reader over
			// the same bytes.
			req.Body = ioutil.NopCloser(bytes.NewReader(reqBody))

			// Install the dumper itself (not the writer it wraps) so handler
			// writes are teed into its buffer as well as sent to the client.
			// NOTE(review): responseDumper only exposes http.ResponseWriter, so
			// optional interfaces of the original writer (http.Flusher,
			// http.Hijacker, ...) are hidden while body dumping is enabled.
			respDumper = newResponseDumper(c)
			c.Response().Override(respDumper)
		}

		// Make the span available to downstream handlers via the request context.
		c.SetRequest(req.WithContext(opentracing.ContextWithSpan(req.Context(), sp)))

		// Run the remaining middleware and the handler.
		// NOTE(review): c.Next() is called as a statement, so handler errors are
		// not observed here; only the status code tag reflects failures.
		c.Next()

		ext.HTTPStatusCode.Set(sp, uint16(c.Response().Status()))

		if config.IsBodyDump {
			logBody(sp, "http.resp.body", respDumper.GetResponse())
		}

		return nil
	}
}

// getRequestID returns the value of the yee.HeaderXRequestID request header
// (typically set by a reverse proxy). When that header is missing or empty,
// a freshly generated random token is returned instead.
func getRequestID(ctx yee.Context) string {
	if id := ctx.Request().Header.Get(yee.HeaderXRequestID); id != "" {
		return id
	}
	// No upstream ID was supplied, so mint one locally.
	return generateToken()
}

// generateToken returns 16 bytes from crypto/rand encoded as a 32-character
// lowercase hex string.
func generateToken() string {
	var raw [16]byte
	// The error is deliberately ignored: crypto/rand failures are not expected
	// in practice, and the token is only a best-effort request ID. On failure
	// the unfilled bytes simply stay zero.
	_, _ = rand.Read(raw[:])
	return fmt.Sprintf("%x", raw[:])
}

// limitString shortens str for logging when it is longer than size bytes.
//
// It keeps about size/2 bytes from the start and about size/2 bytes from the
// end, joined by a "---- skipped ----" marker line. Strings of at most size
// bytes are returned unchanged.
//
// Robustness:
//   - A negative size is treated as 0, so only the marker is returned.
//     Previously a negative size panicked with slice bounds out of range.
//   - Cut points are moved inward to rune boundaries, so a multi-byte UTF-8
//     character is never split. The kept parts can therefore be slightly
//     shorter than size/2 each.
func limitString(str string, size int) string {
	if size < 0 {
		size = 0
	}
	if len(str) <= size {
		return str
	}

	half := size / 2

	// len(str) > size >= 2*half, so both indexes below are in range.
	head := half
	for head > 0 && !utf8.RuneStart(str[head]) {
		head--
	}
	tail := len(str) - half
	for tail < len(str) && !utf8.RuneStart(str[tail]) {
		tail++
	}

	return str[:head] + "\n---- skipped ----\n" + str[tail:]
}

// newResponseDumper wraps the current response writer of resp in a
// responseDumper. Its Write sends bytes to that writer and also appends them
// to a fresh in-memory buffer.
func newResponseDumper(resp yee.Context) *responseDumper {
	underlying := resp.Response().Writer()
	d := &responseDumper{
		ResponseWriter: underlying,
		buf:            &bytes.Buffer{},
	}
	d.mw = io.MultiWriter(underlying, d.buf)
	return d
}

// Write forwards b to d.mw, which the constructor wires to both the wrapped
// ResponseWriter and the capture buffer. It reports mw's byte count and error.
func (d *responseDumper) Write(b []byte) (int, error) {
	n, err := d.mw.Write(b)
	return n, err
}

// GetResponse returns the bytes captured in the dump buffer so far, as a string.
func (d *responseDumper) GetResponse() string {
	captured := d.buf.String()
	return captured
}

// logError records err on span. It logs the message under the
// "error.message" key and tags the span with error=true.
func logError(span opentracing.Span, err error) {
	msg := err.Error()
	span.LogKV("error.message", msg)
	span.SetTag("error", true)
}
